{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Homework 4\n",
    "\n",
    "Today we'll start by reproducing the DQN and then try improving it with the tricks we learned on the lecture:\n",
    "* Target networks\n",
    "* Double q-learning\n",
    "* Prioritized experience replay\n",
    "* Dueling DQN\n",
    "* Bootstrap DQN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt",
    "\n",
    "import numpy as np",
    "\n",
    "%matplotlib inline",
    "\n",
    "\n",
    "\n",
    "# If you are running on a server, launch xvfb to record game videos",
    "\n",
    "# Please make sure you have xvfb installed",
    "\n",
    "import os",
    "\n",
    "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:",
    "\n",
    "    !bash ../xvfb start",
    "\n",
    "    os.environ['DISPLAY'] = ':1'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Processing game image (2 pts)\n",
    "\n",
    "Raw atari images are large, 210x160x3 by default. However, we don't need that level of detail in order to learn them.\n",
    "\n",
    "We can thus save a lot of time by preprocessing game image, including\n",
    "* Resizing to a smaller shape\n",
    "* Converting to grayscale\n",
    "* Cropping irrelevant image parts"
   ]
  },
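  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The steps listed above can be sketched with plain numpy. This is a minimal illustration, not the reference solution: the function name `preprocess_frame`, the crop margins and the nearest-neighbour resize are assumptions made for the example, and any image library would do the resize better.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def preprocess_frame(img, out_size=(64, 64), crop=(60, 30, 5, 5)):\n",
    "    # crop = (top, bottom, left, right) margins in pixels; hypothetical values\n",
    "    top, bottom, left, right = crop\n",
    "    h, w = img.shape[:2]\n",
    "    img = img[top:h - bottom, left:w - right]\n",
    "    # grayscale: average the colour channels\n",
    "    gray = img.astype('float32').mean(axis=-1)\n",
    "    # nearest-neighbour resize: pick evenly spaced rows and columns\n",
    "    rows = np.linspace(0, gray.shape[0] - 1, out_size[0]).astype(int)\n",
    "    cols = np.linspace(0, gray.shape[1] - 1, out_size[1]).astype(int)\n",
    "    small = gray[rows][:, cols]\n",
    "    # scale pixel values from [0, 255] to [0, 1]\n",
    "    return (small / 255.0).astype('float32')\n",
    "\n",
    "# a random stand-in for a raw 210x160x3 atari frame\n",
    "frame = np.random.randint(0, 256, size=(210, 160, 3), dtype=np.uint8)\n",
    "obs = preprocess_frame(frame)\n",
    "assert obs.shape == (64, 64) and obs.dtype == 'float32'\n",
    "assert 0 <= obs.min() and obs.max() <= 1\n",
    "```\n",
    "\n",
    "Averaging the channels is the simplest grayscale conversion; a weighted luminance formula is a common alternative."
   ]
  },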
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gym.core import ObservationWrapper",
    "\n",
    "from gym.spaces import Box",
    "\n",
    "\n",
    "from scipy.misc import imresize",
    "\n",
    "\n",
    "\n",
    "class PreprocessAtari(ObservationWrapper):",
    "\n",
    "    def __init__(self, env):",
    "\n",
    "        \"\"\"A gym wrapper that crops, scales image into the desired shapes and optionally grayscales it.\"\"\"",
    "\n",
    "        ObservationWrapper.__init__(self, env)",
    "\n",
    "\n",
    "        self.img_size = (64, 64)",
    "\n",
    "        self.observation_space = Box(0.0, 1.0, self.img_size)",
    "\n",
    "\n",
    "    def observation(self, img):",
    "\n",
    "        \"\"\"what happens to each observation\"\"\"",
    "\n",
    "\n",
    "        # Here's what you need to do:",
    "\n",
    "        #  * crop image, remove irrelevant parts",
    "\n",
    "        #  * resize image to self.img_size",
    "\n",
    "        #     (use imresize imported above or any library you want,",
    "\n",
    "        #      e.g. opencv, skimage, PIL, keras)",
    "\n",
    "        #  * cast image to grayscale",
    "\n",
    "        #  * convert image pixels to (0,1) range, float32 type",
    "\n",
    "\n",
    "        <Your code here >",
    "\n",
    "        return < ... >"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym",
    "\n",
    "\n",
    "\n",
    "def make_env():",
    "\n",
    "    env = gym.make(\"KungFuMasterDeterministic-v0\")  # create raw env",
    "\n",
    "    return PreprocessAtari(env)  # apply your wrapper",
    "\n",
    "\n",
    "\n",
    "# spawn game instance for tests",
    "\n",
    "env = make_env()",
    "\n",
    "\n",
    "observation_shape = env.observation_space.shape",
    "\n",
    "n_actions = env.action_space.n",
    "\n",
    "\n",
    "obs = env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test observation",
    "\n",
    "assert obs.shape == observation_shape",
    "\n",
    "assert obs.dtype == 'float32'",
    "\n",
    "assert len(np.unique(obs)) > 2, \"your image must not be binary\"",
    "\n",
    "assert 0 <= np.min(obs) and np.max(",
    "\n",
    "    obs) <= 1, \"convert image pixels to (0,1) range\"",
    "\n",
    "\n",
    "print \"Formal tests seem fine. Here's an example of what you'll get.\"",
    "\n",
    "\n",
    "plt.title(\"what your network gonna see\")",
    "\n",
    "plt.imshow(obs, interpolation='none', cmap='gray')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=[12, 12])",
    "\n",
    "env.reset()",
    "\n",
    "for i in range(16):",
    "\n",
    "    for _ in range(10):",
    "\n",
    "        new_obs = env.step(env.action_space.sample())[0]",
    "\n",
    "    plt.subplot(4, 4, i+1)",
    "\n",
    "    plt.imshow(new_obs, interpolation='none', cmap='gray')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dispose of the game instance",
    "\n",
    "del env"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Building a DQN (2 pts)\n",
    "Here we define a simple agent that maps game images into Qvalues using simple convolutional neural network.\n",
    "\n",
    "![scheme](https://s18.postimg.cc/gbmsq6gmx/dqn_scheme.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# setup theano/lasagne. Prefer GPU. Fallback to CPU (will print warning)",
    "\n",
    "%env THEANO_FLAGS = floatX = float32",
    "\n",
    "\n",
    "import theano",
    "\n",
    "import lasagne",
    "\n",
    "from lasagne.layers import *",
    "\n",
    "from theano import tensor as T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# observation",
    "\n",
    "observation_layer = InputLayer(",
    "\n",
    "    (None,)+observation_shape)  # game image, [batch,64,64]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 4-tick window over images",
    "\n",
    "from agentnet.memory import WindowAugmentation",
    "\n",
    "\n",
    "# window size [batch,4,64,64]",
    "\n",
    "prev_wnd = InputLayer((None, 4)+observation_shape)",
    "\n",
    "\n",
    "new_wnd = WindowAugmentation( < current observation layer > , prev_wnd)",
    "\n",
    "\n",
    "# if you changed img size, remove assert",
    "\n",
    "assert new_wnd.output_shape == (None, 4, 64, 64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from lasagne.nonlinearities import elu, tanh, softmax, rectify",
    "\n",
    "\n",
    "<network body, growing from new_wnd. several conv layers or something similar would do >",
    "\n",
    "\n",
    "dense = <final dense layer with 256 neurons >"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# qvalues layer",
    "\n",
    "qvalues_layer = <a dense layer that predicts q-values >",
    "\n",
    "\n",
    "assert qvalues_layer.nonlinearity is not rectify"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample actions proportionally to policy_layer",
    "\n",
    "from agentnet.resolver import EpsilonGreedyResolver",
    "\n",
    "action_layer = EpsilonGreedyResolver(qvalues_layer)"
   ]
  },
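  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, epsilon-greedy action selection can be sketched in numpy. This illustrates the general technique and is not AgentNet's implementation of `EpsilonGreedyResolver`; `epsilon_greedy` and its arguments are names made up for the example.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def epsilon_greedy(qvalues, epsilon, rng):\n",
    "    # qvalues: [batch, n_actions]\n",
    "    batch, n_actions = qvalues.shape\n",
    "    greedy = qvalues.argmax(axis=-1)                       # best action per row\n",
    "    random_actions = rng.randint(0, n_actions, size=batch)  # uniform random actions\n",
    "    explore = rng.rand(batch) < epsilon                     # explore with probability epsilon\n",
    "    return np.where(explore, random_actions, greedy)\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "q = np.array([[0.1, 0.9], [0.7, 0.2]])\n",
    "# with epsilon = 0 the choice is always greedy\n",
    "assert np.array_equal(epsilon_greedy(q, 0.0, rng), [1, 0])\n",
    "```\n",
    "\n",
    "With `epsilon=1` every action is drawn uniformly at random, which is the pure-exploration extreme."
   ]
  },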
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define agent\n",
    "Here you will need to declare how your agent works\n",
    "\n",
    "* `observation_layers` and `action_layers` are the input and output of agent in MDP.\n",
    "* `policy_estimators` must contain whatever you need for training\n",
    " * In our case, that's `qvalues_layer`, but you'll need to add more when implementing target network.\n",
    "* agent_states contains our frame buffer. \n",
    " * The code `{new_wnd:prev_wnd}` reads as \"`new_wnd becomes prev_wnd next turn`\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from agentnet.agent import Agent",
    "\n",
    "# agent",
    "\n",
    "agent = Agent(observation_layers= < ... >,",
    "\n",
    "              policy_estimators= < ... >,",
    "\n",
    "              action_layers= < ... >,",
    "\n",
    "              agent_states={new_wnd: prev_wnd},)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create and manage a pool of atari sessions to play with\n",
    "\n",
    "* To make training more stable, we shall have an entire batch of game sessions each happening independent of others\n",
    "* Why several parallel agents help training: http://arxiv.org/pdf/1602.01783v1.pdf\n",
    "* Alternative approach: store more sessions: https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from agentnet.experiments.openai_gym.pool import EnvPool",
    "\n",
    "\n",
    "pool = EnvPool(agent, make_env, n_games=16)  # 16 parallel game sessions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "% % time",
    "\n",
    "# interact for 7 ticks",
    "\n",
    "_, action_log, reward_log, _, _, _ = pool.interact(5)",
    "\n",
    "\n",
    "print('actions:')",
    "\n",
    "print(action_log[0])",
    "\n",
    "print(\"rewards\")",
    "\n",
    "print(reward_log[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load first sessions (this function calls interact and remembers sessions)",
    "\n",
    "SEQ_LENGTH = 10  # sub-session length",
    "\n",
    "pool.update(SEQ_LENGTH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Q-learning\n",
    "\n",
    "We train our agent based on sessions it has played in `pool.update(SEQ_LENGTH)`\n",
    "\n",
    "To do so, we first obtain sequences of observations, rewards, actions, q-values, etc.\n",
    "\n",
    "Actions and rewards have shape `[n_games,seq_length]`, q-values are `[n_games,seq_length,n_actions]`"
   ]
  },
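  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shape conventions above can be checked on random numpy arrays before touching theano. This is only a sketch with made-up sizes; the indexing trick stands in for what a helper such as `get_values_for_actions` is expected to compute.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "n_games, seq_length, n_actions = 2, 3, 4  # toy sizes for the example\n",
    "rng = np.random.RandomState(0)\n",
    "qvalues = rng.randn(n_games, seq_length, n_actions)              # [batch, time, n_actions]\n",
    "actions = rng.randint(0, n_actions, size=(n_games, seq_length))  # [batch, time]\n",
    "\n",
    "# V(s) = max_a Q(s, a): reduce over the last (action) axis\n",
    "state_values = qvalues.max(axis=-1)\n",
    "\n",
    "# Q(s, a) for the actions that were actually taken\n",
    "batch_idx, time_idx = np.indices(actions.shape)\n",
    "action_qvalues = qvalues[batch_idx, time_idx, actions]\n",
    "\n",
    "assert state_values.shape == action_qvalues.shape == (n_games, seq_length)\n",
    "assert np.all(action_qvalues <= state_values)\n",
    "```\n",
    "\n",
    "Both results drop the action axis, so they line up with `rewards` and `is_alive` element by element."
   ]
  },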
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get agent's Qvalues obtained via experience replay",
    "\n",
    "replay = pool.experience_replay",
    "\n",
    "\n",
    "actions, rewards, is_alive = replay.actions[0], replay.rewards, replay.is_alive",
    "\n",
    "\n",
    "_, _, _, _, qvalues = agent.get_sessions(",
    "\n",
    "    replay,",
    "\n",
    "    session_length=SEQ_LENGTH,",
    "\n",
    "    experience_replay=True,",
    "\n",
    ")",
    "\n",
    "\n",
    "assert actions.ndim == rewards.ndim == is_alive.ndim == 2, \"actions, rewards and is_alive must have shape [batch,time]\"",
    "\n",
    "assert qvalues.ndim == 3, \"q-values must have shape [batch,time,n_actions]\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute V(s) as Qvalues of best actions.",
    "\n",
    "# For homework assignment, you will need to use target net",
    "\n",
    "# or special double q-learning objective here",
    "\n",
    "\n",
    "state_values_target = <compute V(s) 2d tensor by taking T.argmax of qvalues over correct axis >",
    "\n",
    "\n",
    "assert state_values_target.eval().shape = qvalues.eval().shape[:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from agentnet.learning.generic import get_n_step_value_reference",
    "\n",
    "\n",
    "# get reference Q-values via Q-learning algorithm",
    "\n",
    "reference_qvalues = get_n_step_value_reference(",
    "\n",
    "    state_values=state_values_target,",
    "\n",
    "    rewards=rewards/100.,",
    "\n",
    "    is_alive=is_alive,",
    "\n",
    "    n_steps=10,",
    "\n",
    "    gamma_or_gammas=0.99,",
    "\n",
    ")",
    "\n",
    "\n",
    "# consider it constant",
    "\n",
    "from theano.gradient import disconnected_grad",
    "\n",
    "reference_qvalues = disconnected_grad(reference_qvalues)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get predicted Q-values for committed actions by both current and target networks",
    "\n",
    "from agentnet.learning.generic import get_values_for_actions",
    "\n",
    "action_qvalues = get_values_for_actions(qvalues, actions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# loss for Qlearning =",
    "\n",
    "# (Q(s,a) - (r+ gamma*r' + gamma^2*r'' + ...  +gamma^10*Q(s_{t+10},a_max)))^2",
    "\n",
    "\n",
    "elwise_mse_loss = <mean squared error between action qvalues and reference qvalues >",
    "\n",
    "\n",
    "# mean over all batches and time ticks",
    "\n",
    "loss = (elwise_mse_loss*is_alive).mean()"
   ]
  },
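  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The n-step reference and the masked squared error from the comment above can be sketched in numpy for a single session. This is an illustration under stated assumptions, not the AgentNet implementation: `n_step_targets`, `next_values` and `dones` are names and conventions chosen for the example.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def n_step_targets(rewards, next_values, dones, n_steps, gamma):\n",
    "    # rewards[t]: reward received after acting at step t\n",
    "    # next_values[t]: estimate of V(s_{t+1})\n",
    "    # dones[t]: True if s_{t+1} is terminal\n",
    "    horizon = len(rewards)\n",
    "    targets = np.zeros(horizon)\n",
    "    for t in range(horizon):\n",
    "        ret, discount = 0.0, 1.0\n",
    "        last = min(t + n_steps, horizon) - 1\n",
    "        for k in range(t, last + 1):\n",
    "            ret += discount * rewards[k]\n",
    "            discount *= gamma\n",
    "            if dones[k]:\n",
    "                break  # episode ended: nothing to bootstrap from\n",
    "        else:\n",
    "            # no terminal state reached: bootstrap from V(s_{last+1})\n",
    "            ret += discount * next_values[last]\n",
    "        targets[t] = ret\n",
    "    return targets\n",
    "\n",
    "rewards = np.array([1.0, 1.0, 1.0])\n",
    "next_values = np.array([10.0, 10.0, 10.0])\n",
    "dones = np.array([False, False, True])\n",
    "targets = n_step_targets(rewards, next_values, dones, n_steps=2, gamma=0.5)\n",
    "assert np.allclose(targets, [4.0, 1.5, 1.0])\n",
    "\n",
    "# squared error, masked by is_alive and averaged over all ticks\n",
    "predicted = np.array([3.0, 1.5, 2.0])\n",
    "is_alive = np.array([1.0, 1.0, 1.0])\n",
    "loss = (((predicted - targets) ** 2) * is_alive).mean()\n",
    "assert np.isclose(loss, 2.0 / 3.0)\n",
    "```\n",
    "\n",
    "Near the end of the sequence fewer than `n_steps` rewards are available, so the sketch falls back to a shorter return and bootstraps from the last available value."
   ]
  },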
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Since it's a single lasagne network, one can get it's weights, output, etc",
    "\n",
    "weights = <get all trainable params >",
    "\n",
    "weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute weight updates",
    "\n",
    "updates = <your favorite optimizer >",
    "\n",
    "\n",
    "# compile train function",
    "\n",
    "train_step = theano.function([], loss, updates=updates)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Demo run\n",
    "as usual..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "action_layer.epsilon.set_value(0.05)",
    "\n",
    "untrained_reward = np.mean(pool.evaluate(save_path=\"./records\",",
    "\n",
    "                                         record_video=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# show video",
    "\n",
    "from IPython.display import HTML",
    "\n",
    "import os",
    "\n",
    "\n",
    "video_names = list(",
    "\n",
    "    filter(lambda s: s.endswith(\".mp4\"), os.listdir(\"./records/\")))",
    "\n",
    "\n",
    "HTML(\"\"\"",
    "\n",
    "<video width=\"640\" height=\"480\" controls>",
    "\n",
    "  <source src=\"{}\" type=\"video/mp4\">",
    "\n",
    "</video>",
    "\n",
    "\"\"\".format(\"./records/\"+video_names[-1]))  # this may or may not be _last_ video. Try other indices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# starting epoch",
    "\n",
    "epoch_counter = 1",
    "\n",
    "\n",
    "# full game rewards",
    "\n",
    "rewards = {}",
    "\n",
    "loss, reward_per_tick, reward = 0, 0, 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import trange",
    "\n",
    "from IPython.display import clear_output",
    "\n",
    "\n",
    "\n",
    "for i in trange(150000):",
    "\n",
    "\n",
    "    # update agent's epsilon (in e-greedy policy)",
    "\n",
    "    current_epsilon = 0.05 + 0.45*np.exp(-epoch_counter/20000.)",
    "\n",
    "    action_layer.epsilon.set_value(np.float32(current_epsilon))",
    "\n",
    "\n",
    "    # play",
    "\n",
    "    pool.update(SEQ_LENGTH)",
    "\n",
    "\n",
    "    # train",
    "\n",
    "    loss = 0.95*loss + 0.05*train_step()",
    "\n",
    "\n",
    "    if epoch_counter % 10 == 0:",
    "\n",
    "        # average reward per game tick in current experience replay pool",
    "\n",
    "        reward_per_tick = 0.95*reward_per_tick + 0.05 * \\",
    "\n",
    "            pool.experience_replay.rewards.get_value().mean()",
    "\n",
    "        print(\"iter=%i\\tepsilon=%.3f\\tloss=%.3f\\treward/tick=%.3f\" % (epoch_counter,",
    "\n",
    "                                                                      current_epsilon,",
    "\n",
    "                                                                      loss,",
    "\n",
    "                                                                      reward_per_tick))",
    "\n",
    "\n",
    "    # record current learning progress and show learning curves",
    "\n",
    "    if epoch_counter % 100 == 0:",
    "\n",
    "        action_layer.epsilon.set_value(0.05)",
    "\n",
    "        reward = 0.95*reward + 0.05*np.mean(pool.evaluate(record_video=False))",
    "\n",
    "        action_layer.epsilon.set_value(np.float32(current_epsilon))",
    "\n",
    "\n",
    "        rewards[epoch_counter] = reward",
    "\n",
    "\n",
    "        clear_output(True)",
    "\n",
    "        plt.plot(*zip(*sorted(rewards.items(), key=lambda (t, r): t)))",
    "\n",
    "        plt.show()",
    "\n",
    "\n",
    "    epoch_counter += 1",
    "\n",
    "\n",
    "\n",
    "# Time to drink some coffee!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluating results\n",
    " * Here we plot learning curves and sample testimonials"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd",
    "\n",
    "plt.plot(*zip(*sorted(rewards.items(), key=lambda k: k[0])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from agentnet.utils.persistence import save, load",
    "\n",
    "save(action_layer, \"pacman.pcl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "action_layer.epsilon.set_value(0.05)",
    "\n",
    "rw = pool.evaluate(n_games=20, save_path=\"./records\", record_video=False)",
    "\n",
    "print(\"mean session score=%f.5\" % np.mean(rw))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# show video",
    "\n",
    "from IPython.display import HTML",
    "\n",
    "import os",
    "\n",
    "\n",
    "video_names = list(",
    "\n",
    "    filter(lambda s: s.endswith(\".mp4\"), os.listdir(\"./records/\")))",
    "\n",
    "\n",
    "HTML(\"\"\"",
    "\n",
    "<video width=\"640\" height=\"480\" controls>",
    "\n",
    "  <source src=\"{}\" type=\"video/mp4\">",
    "\n",
    "</video>",
    "\n",
    "\"\"\".format(\"./videos/\"+video_names[-1]))  # this may or may not be _last_ video. Try other indices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Assignment part I (5 pts)\n",
    "\n",
    "We'll start by implementing target network to stabilize training.\n",
    "\n",
    "There are two ways to do so: \n",
    "\n",
    "\n",
    "__1)__ Manually write lasagne network, or clone it via [one of those methods](https://github.com/Lasagne/Lasagne/issues/720).\n",
    "\n",
    "You will need to implement loading weights from original network to target network.\n",
    "\n",
    "We recommend thoroughly debugging your code on simple tests before applying it in atari dqn.\n",
    "\n",
    "__2)__ Use pre-build functionality from [here](http://agentnet.readthedocs.io/en/master/modules/target_network.html)\n",
    "\n",
    "```\n",
    "from agentnet.target_network import TargetNetwork\n",
    "target_net = TargetNetwork(qvalues_layer)\n",
    "old_qvalues = target_net.output_layers\n",
    "\n",
    "#agent's policy_estimators must now become (qvalues,old_qvalues)\n",
    "\n",
    "_,_,_,_,(qvalues,old_qvalues) = agent.get_sessions(...) #replaying experience\n",
    "\n",
    "\n",
    "target_net.load_weights()#loads weights, so target network is now exactly same as main network\n",
    "\n",
    "target_net.load_weights(0.01)# w_target = 0.99*w_target + 0.01*w_new\n",
    "```"
   ]
  },
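  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you go the manual route, the weight-loading rule itself is easy to test in isolation. Below is a numpy sketch of the hard copy and the moving-average update; `soft_update` and `tau` are names chosen for the example, and lists of arrays stand in for the two networks' parameters.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def soft_update(target_weights, main_weights, tau=0.01):\n",
    "    # in-place update: w_target = (1 - tau) * w_target + tau * w_main\n",
    "    for w_t, w_m in zip(target_weights, main_weights):\n",
    "        w_t *= (1.0 - tau)\n",
    "        w_t += tau * w_m\n",
    "\n",
    "main = [np.ones((2, 2)), np.full(3, 2.0)]\n",
    "target = [np.zeros((2, 2)), np.zeros(3)]\n",
    "\n",
    "# tau = 1 is a hard copy: the target becomes identical to the main network\n",
    "soft_update(target, main, tau=1.0)\n",
    "assert all(np.allclose(t, m) for t, m in zip(target, main))\n",
    "\n",
    "# after the main network changes, a small tau moves the target only slightly\n",
    "main[0] += 1.0\n",
    "soft_update(target, main, tau=0.01)\n",
    "assert np.allclose(target[0], 0.99 * 1.0 + 0.01 * 2.0)\n",
    "```\n",
    "\n",
    "A hard copy every few thousand iterations and a small `tau` at every iteration are both common choices; which one works better here is something to find out experimentally."
   ]
  },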
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bonus I (2+ pts)\n",
    "\n",
    "Implement and train double q-learning.\n",
    "\n",
    "This task contains of\n",
    "* Implementing __double q-learning__ or __dueling q-learning__ or both (see tips below)\n",
    "* Training a network till convergence\n",
    "  * Full points will be awwarded if your network gets average score of >=10 (see \"evaluating results\")\n",
    "  * Higher score = more points as usual\n",
    "  * If you're running out of time, it's okay to submit a solution that hasn't converged yet and updating it when it converges. _Lateness penalty will not increase for second submission_, so submitting first one in time gets you no penalty.\n",
    "\n",
    "\n",
    "#### Tips:\n",
    "* Implementing __double q-learning__ shouldn't be a problem if you've already have target networks in place.\n",
    "  * As one option, use `get_values_for_actions(<some q-values tensor3>,<some indices>)`.\n",
    "  * You will probably need `T.argmax` to select best actions\n",
    "  * Here's an original [article](https://arxiv.org/abs/1509.06461)\n",
    "\n",
    "* __Dueling__ architecture is also quite straightforward if you have standard DQN.\n",
    "  * You will need to change network architecture, namely the q-values layer\n",
    "  * It must now contain two heads: V(s) and A(s,a), both dense layers\n",
    "  * You should then add them up via elemwise sum layer or a [custom](http://lasagne.readthedocs.io/en/latest/user/custom_layers.html) layer.\n",
    "  * Here's an [article](https://arxiv.org/pdf/1511.06581.pdf)\n",
    "  \n",
    "Here's a template for your convenience:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from lasagne.layers import *",
    "\n",
    "\n",
    "\n",
    "class DuelingQvaluesLayer(MergeLayer):",
    "\n",
    "    def get_output_for(self, inputs, **tags):",
    "\n",
    "        V, A = inputs",
    "\n",
    "        return < add them up: ) >",
    "\n",
    "\n",
    "    def get_output_shape_for(self, input_shapes, **tags):",
    "\n",
    "        V_shape, A_shape=input_shapes",
    "\n",
    "        assert len(",
    "\n",
    "            V_shape) == 2 and V_shape[-1] == 1, \"V layer (first param) shape must be [batch,tick,1]\"",
    "\n",
    "        return A_shape  # shape of q-values is same as predicted advantages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# mock-up tests",
    "\n",
    "import theano.tensor as T",
    "\n",
    "v_tensor = -T.arange(10).reshape((10, 1))",
    "\n",
    "V = InputLayer((None, 1), v_tensor)",
    "\n",
    "\n",
    "a_tensor = T.arange(30).reshape((10, 3))",
    "\n",
    "A = InputLayer((None, 1), a_tensor)",
    "\n",
    "\n",
    "Q = DuelingQvaluesLayer([V, A])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np",
    "\n",
    "assert np.allclose(get_output(Q).eval(), (v_tensor+a_tensor).eval())",
    "\n",
    "print(\"looks good\")"
   ]
  },
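  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both tips can be prototyped in numpy before writing the theano version. The sketch below is an illustration, not the expected solution; the function names and the `center` flag are made up for the example. The plain sum is what the mock-up test above checks, while the mean-centred variant is the one proposed in the dueling paper.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def double_q_state_values(qvalues_online, qvalues_target):\n",
    "    # both arrays have shape [batch, time, n_actions]\n",
    "    best_actions = qvalues_online.argmax(axis=-1)  # actions chosen by the online network\n",
    "    b, t = np.indices(best_actions.shape)\n",
    "    return qvalues_target[b, t, best_actions]      # values taken from the target network\n",
    "\n",
    "def dueling_qvalues(V, A, center=False):\n",
    "    # V: [batch, 1], A: [batch, n_actions]; V broadcasts over the action axis\n",
    "    if center:\n",
    "        A = A - A.mean(axis=-1, keepdims=True)\n",
    "    return V + A\n",
    "\n",
    "q_online = np.array([[[1.0, 5.0, 2.0]]])\n",
    "q_target = np.array([[[10.0, 20.0, 30.0]]])\n",
    "# the online network prefers action 1, so the target network's value for action 1 is used\n",
    "assert double_q_state_values(q_online, q_target)[0, 0] == 20.0\n",
    "\n",
    "V = np.array([[1.0]])\n",
    "A = np.array([[0.0, 2.0, 4.0]])\n",
    "assert np.allclose(dueling_qvalues(V, A), [[1.0, 3.0, 5.0]])\n",
    "assert np.allclose(dueling_qvalues(V, A, center=True), [[-1.0, 1.0, 3.0]])\n",
    "```\n",
    "\n",
    "Note that the double q-learning value (20) differs from the plain maximum of the target network (30); that gap is exactly the overestimation the method is meant to reduce."
   ]
  },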
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bonus II (5+ pts): Prioritized experience replay\n",
    "\n",
    "In this section, you're invited to implement prioritized experience replay\n",
    "\n",
    "* You will probably need to provide a custom data structure\n",
    "* Once pool.update is called, collect the pool.experience_replay.observations, actions, rewards and is_alive and store them in your data structure\n",
    "* You can now sample such transitions in proportion to the error (see [article](https://arxiv.org/abs/1511.05952)) for training.\n",
    "\n",
    "It's probably more convenient to explicitly declare inputs for \"sample observations\", \"sample actions\" and so on to plug them into q-learning.\n",
    "\n",
    "Prioritized (and even normal) experience replay should greatly reduce amount of game sessions you need to play in order to achieve good performance. \n",
    "\n",
    "While it's effect on runtime is limited for atari, more complicated envs (further in the course) will certainly benefit for it.\n",
    "\n",
    "Prioritized experience replay only supports off-policy algorithms, so pls enforce `n_steps=1` in your q-learning reference computation (default is 10)."
   ]
  },
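  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sampling in proportion to the error can be sketched in a few lines of numpy. This is a simplified illustration of the proportional variant from the article, without the efficient sum-tree structure; `sample_prioritized` and its default `alpha`, `beta` and `eps` values are assumptions for the example.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def sample_prioritized(td_errors, batch_size, alpha=0.6, beta=0.4, eps=1e-6, rng=None):\n",
    "    rng = rng or np.random.RandomState()\n",
    "    # priority grows with the absolute TD error; eps keeps every item reachable\n",
    "    priorities = (np.abs(td_errors) + eps) ** alpha\n",
    "    probs = priorities / priorities.sum()\n",
    "    idx = rng.choice(len(td_errors), size=batch_size, p=probs)\n",
    "    # importance-sampling weights compensate for the non-uniform sampling\n",
    "    weights = (len(td_errors) * probs[idx]) ** (-beta)\n",
    "    weights /= weights.max()  # normalise so that the largest weight is 1\n",
    "    return idx, weights\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "td_errors = np.array([0.1, 0.1, 5.0, 0.1])\n",
    "idx, weights = sample_prioritized(td_errors, batch_size=8, rng=rng)\n",
    "assert idx.shape == weights.shape == (8,)\n",
    "assert np.all((weights > 0) & (weights <= 1))\n",
    "```\n",
    "\n",
    "The returned `weights` would multiply the per-sample loss, and the stored errors need to be refreshed after each training step so that priorities follow the current network."
   ]
  },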
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
